第 7 课:序列模型
之前的几课中,我们学习了神经网络的实现方式,但是如何实现序列模型结构呢?
序列模型是处理时间序列、文本、语音等顺序数据的核心工具,常见模型包括:
| 模型 | 简介 | 适用场景 |
|---|---|---|
| RNN (Recurrent Neural Network) | 最基本的循环网络,能记住前面的状态 | 简单的时间序列、短文本 |
| LSTM (Long Short-Term Memory) | 改进版 RNN,引入“记忆门控”机制,能处理长期依赖 | NLP、语音识别 |
| GRU (Gated Recurrent Unit) | 类似 LSTM,但结构更简单 | 资源有限的设备 |
| Transformer | 并行处理序列,不依赖循环结构,靠“注意力机制”捕捉长依赖 | BERT, GPT 等 NLP 模型基础 |
| Temporal Convolutional Network (TCN) | 用卷积代替循环网络,提升速度和稳定性 | 时间序列预测 |
| Seq2Seq(编码器-解码器结构) | 输入输出均为序列,如机器翻译 | 翻译、摘要、对话系统 |
我们会从 RNN → LSTM/GRU → Transformer 的路径逐步展开。
注:NLP 是自然语言处理
(一)RNN
目标功能:用字符序列预测下一个字符,比如 “hell” → “o”,即训练模型生成句子或单词。
推荐数据集:tiny Shakespeare(常用RNN测试数据集,只有几百 KB,来自莎士比亚作品片段 karpathy/char-rnn)
以 "hello" 为例,训练模型让它学会从 "h" → "e" → "l" → "l" → "o":
序列模型仍然需要 DataSet 和 DataLoader
项目结构
char_rnn_project/
│
├── char_rnn.ipynb # 主 Jupyter Notebook
└── data/
└── tinyshakespeare.txt # 数据文件(我们会下载)
步骤一: 下载数据
我们用 Python 自动下载 karpathy/char-rnn 的文本:
import os
import requests
os.makedirs("data", exist_ok=True)
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
save_path = "data/tinyshakespeare.txt"
if not os.path.exists(save_path):
response = requests.get(url)
with open(save_path, "w", encoding="utf-8") as f:
f.write(response.text)
print("数据已保存到 data/tinyshakespeare.txt")
步骤二: 数据预处理
我们把所有字符转换为整数索引,生成序列对(如 "hell" → "ello"):
读取数据
with open("data/tinyshakespeare.txt", "r", encoding="utf-8") as f:
text = f.read()
建立词汇表
# 建立字符到整数的映射
print(set(text))
chars = sorted(list(set(text))) # set(text) 提取 text 中所有不同的字符,用集合去重。
print(chars)
vocab_size = len(chars)
print(f"共 {vocab_size} 个唯一字符")
{'!', 'U', 'c', '?', ':', 'i', 'W', 'x', 'R', 'g', 'M', 'H', 'B', 'N', 'm', 'A', 'S', 'l', 'F', ' ', 'o', '\n', 'k', 'P', '3', 'C', 'Y', ',', 'D', 'y', 'r', 'e', ';', 's', 'J', '&', 'b', 'a', '-', 'n', 'u', 'L', 'O', 'E', 'G', 'X', 'p', 't', '.', 'Z', 'K', 'f', 'I', 'z', 'j', '$', 'w', 'v', "'", 'V', 'Q', 'q', 'h', 'd', 'T'}
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
共 65 个唯一字符
# 构建一个字典,把每个字符映射到一个整数编号。
stoi = {ch: i for i, ch in enumerate(chars)} # stoi 是 "string to index" 的缩写。
itos = {i: ch for ch, i in stoi.items()} # itos 是 "index to string" 的缩写。
设计把 string 转换为整数列表的函数 encode,以及把整数列表转换回 string 的函数 decode。
# 编码函数
def encode(s):
return [stoi[c] for c in s]
def decode(l):
return ''.join([itos[i] for i in l])